#if !BESTHTTP_DISABLE_ALTERNATE_SSL && (!UNITY_WEBGL || UNITY_EDITOR)
#pragma warning disable
using System;
using System.IO;

using BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities;
using BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.IO;

namespace BestHTTP.SecureProtocol.Org.BouncyCastle.Crypto.Tls
{
    /// <summary>An implementation of the TLS 1.0/1.1/1.2 record layer, allowing downgrade to SSLv3.</summary>
    internal sealed class RecordStream
    {
        // 2^14 bytes: the plaintext fragment limit referenced by RFC 5246 6.2.1.
        private const int DEFAULT_PLAINTEXT_LIMIT = (1 << 14);

        // Record header layout: type (1 byte), version (2 bytes), length (2 bytes).
        internal const int TLS_HEADER_SIZE = 5;
        internal const int TLS_HEADER_TYPE_OFFSET = 0;
        internal const int TLS_HEADER_VERSION_OFFSET = 1;
        internal const int TLS_HEADER_LENGTH_OFFSET = 3;

        private readonly TlsProtocol mHandler;
        private readonly Stream mInput;
        private readonly Stream mOutput;
        private TlsCompression mPendingCompression = null, mReadCompression = null, mWriteCompression = null;
        private TlsCipher mPendingCipher = null, mReadCipher = null, mWriteCipher = null;
        private SequenceNumber mReadSeqNo = new SequenceNumber(), mWriteSeqNo = new SequenceNumber();

        // Scratch buffer shared by the (de)compression paths; drained by GetBufferContents().
        private readonly MemoryStream mBuffer = new MemoryStream();

        private TlsHandshakeHash mHandshakeHash = null;
        private readonly BaseOutputStream mHandshakeHashUpdater;

        private ProtocolVersion mReadVersion = null, mWriteVersion = null;
        private bool mRestrictReadVersion = true;

        private int mPlaintextLimit, mCompressedLimit, mCiphertextLimit;

        /// <summary>
        /// Creates a record layer over the given streams, starting with null compression in both directions.
        /// </summary>
        /// <param name="handler">Receives every decoded record via <c>ProcessRecord</c>.</param>
        /// <param name="input">Stream records are read from.</param>
        /// <param name="output">Stream records are written to.</param>
        internal RecordStream(TlsProtocol handler, Stream input, Stream output)
        {
            this.mHandler = handler;
            this.mInput = input;
            this.mOutput = output;
            this.mReadCompression = new TlsNullCompression();
            this.mWriteCompression = this.mReadCompression;
            this.mHandshakeHashUpdater = new HandshakeHashUpdateStream(this);
        }

        /// <summary>
        /// Installs a null cipher for both directions and starts a fresh <see cref="DeferredHash"/> handshake hash.
        /// Also resets the size limits to their defaults.
        /// </summary>
        internal /*virtual*/ void Init(TlsContext context)
        {
            this.mReadCipher = new TlsNullCipher(context);
            this.mWriteCipher = this.mReadCipher;
            this.mHandshakeHash = new DeferredHash();
            this.mHandshakeHash.Init(context);
            SetPlaintextLimit(DEFAULT_PLAINTEXT_LIMIT);
        }

        /// <summary>Returns the current maximum plaintext fragment length.</summary>
        internal /*virtual*/ int GetPlaintextLimit()
        {
            return mPlaintextLimit;
        }

        /// <summary>
        /// Sets the plaintext limit and derives the compressed and ciphertext limits from it.
        /// The compressed limit is 1024 bytes larger, and the ciphertext limit is a further 1024 bytes larger.
        /// This mirrors the expansion RFC 5246 6.2.2 / 6.2.3 allow.
        /// </summary>
        internal /*virtual*/ void SetPlaintextLimit(int plaintextLimit)
        {
            this.mPlaintextLimit = plaintextLimit;
            this.mCompressedLimit = this.mPlaintextLimit + 1024;
            this.mCiphertextLimit = this.mCompressedLimit + 1024;
        }

        /// <summary>
        /// The protocol version every received record must carry while read-version restriction is on.
        /// While this is null, the first record read by <see cref="ReadRecord"/> sets it.
        /// </summary>
        internal /*virtual*/ ProtocolVersion ReadVersion
        {
            get { return mReadVersion; }
            set { this.mReadVersion = value; }
        }

        /// <summary>
        /// Sets the version stamped on outgoing record headers.
        /// Nothing is written while this is null.
        /// </summary>
        internal /*virtual*/ void SetWriteVersion(ProtocolVersion writeVersion)
        {
            this.mWriteVersion = writeVersion;
        }

        /**
         * RFC 5246 E.1. "Earlier versions of the TLS specification were not fully clear on what the
         * record layer version number (TLSPlaintext.version) should contain when sending ClientHello
         * (i.e., before it is known which version of the protocol will be employed). Thus, TLS servers
         * compliant with this specification MUST accept any value {03,XX} as the record layer version
         * number for ClientHello."
         *
         * When restriction is disabled, incoming headers are only checked for a {03,XX} version.
         */
        internal /*virtual*/ void SetRestrictReadVersion(bool enabled)
        {
            this.mRestrictReadVersion = enabled;
        }

        /// <summary>
        /// Stages the compression and cipher that the next ChangeCipherSpec in each direction will activate.
        /// </summary>
        internal /*virtual*/ void SetPendingConnectionState(TlsCompression tlsCompression, TlsCipher tlsCipher)
        {
            this.mPendingCompression = tlsCompression;
            this.mPendingCipher = tlsCipher;
        }

        /// <summary>
        /// Switches outgoing records to the pending state and restarts the write sequence number at zero.
        /// </summary>
        /// <exception cref="TlsFatalAlert">handshake_failure if no pending state has been staged.</exception>
        internal /*virtual*/ void SentWriteCipherSpec()
        {
            if (mPendingCompression == null || mPendingCipher == null)
                throw new TlsFatalAlert(AlertDescription.handshake_failure);

            this.mWriteCompression = this.mPendingCompression;
            this.mWriteCipher = this.mPendingCipher;
            this.mWriteSeqNo = new SequenceNumber();
        }

        /// <summary>
        /// Switches incoming records to the pending state and restarts the read sequence number at zero.
        /// </summary>
        /// <exception cref="TlsFatalAlert">handshake_failure if no pending state has been staged.</exception>
        internal /*virtual*/ void ReceivedReadCipherSpec()
        {
            if (mPendingCompression == null || mPendingCipher == null)
                throw new TlsFatalAlert(AlertDescription.handshake_failure);

            this.mReadCompression = this.mPendingCompression;
            this.mReadCipher = this.mPendingCipher;
            this.mReadSeqNo = new SequenceNumber();
        }

        /// <summary>
        /// Verifies that both directions have adopted the pending state, then clears the pending state.
        /// </summary>
        /// <exception cref="TlsFatalAlert">handshake_failure if either direction has not switched.</exception>
        internal /*virtual*/ void FinaliseHandshake()
        {
            if (mReadCompression != mPendingCompression || mWriteCompression != mPendingCompression
                || mReadCipher != mPendingCipher || mWriteCipher != mPendingCipher)
            {
                throw new TlsFatalAlert(AlertDescription.handshake_failure);
            }
            this.mPendingCompression = null;
            this.mPendingCipher = null;
        }

        /// <summary>
        /// Validates a record header without consuming any input.
        /// Unlike <see cref="ReadRecord"/>, this never sets <see cref="ReadVersion"/>.
        /// </summary>
        /// <exception cref="TlsFatalAlert">On an unknown type, a version mismatch or an oversized length.</exception>
        internal /*virtual*/ void CheckRecordHeader(byte[] recordHeader)
        {
            int length;
            ValidateRecordHeader(recordHeader, false, out length);
        }

        /// <summary>
        /// Reads one record from the input stream, decodes it, and passes the plaintext to the handler.
        /// </summary>
        /// <returns>
        /// False when <c>ReadAllOrNothing</c> yields no header.
        /// This is presumably a clean end of stream; confirm against TlsUtilities.
        /// Otherwise true.
        /// </returns>
        internal /*virtual*/ bool ReadRecord()
        {
            byte[] recordHeader = TlsUtilities.ReadAllOrNothing(TLS_HEADER_SIZE, mInput);
            if (recordHeader == null)
                return false;

            int length;
            byte type = ValidateRecordHeader(recordHeader, true, out length);

            byte[] plaintext = DecodeAndVerify(type, mInput, length);
            mHandler.ProcessRecord(type, plaintext, 0, plaintext.Length);
            return true;
        }

        /// <summary>
        /// Reads <paramref name="len"/> bytes of record body from <paramref name="input"/> and decrypts them.
        /// Decryption uses the current read cipher and the next read sequence number.
        /// The result is then decompressed and checked against the compressed and plaintext limits.
        /// </summary>
        /// <returns>The record plaintext.</returns>
        internal /*virtual*/ byte[] DecodeAndVerify(byte type, Stream input, int len)
        {
            byte[] buf = TlsUtilities.ReadFully(len, input);

            long seqNo = mReadSeqNo.NextValue(AlertDescription.unexpected_message);
            byte[] decoded = mReadCipher.DecodeCiphertext(seqNo, type, buf, 0, buf.Length);

            CheckLength(decoded.Length, mCompressedLimit, AlertDescription.record_overflow);

            /*
             * TODO 5246 6.2.2. Implementation note: Decompression functions are responsible for
             * ensuring that messages cannot cause internal buffer overflows.
             */
            Stream cOut = mReadCompression.Decompress(mBuffer);
            // A compression that returns mBuffer itself performs no transformation.
            if (cOut != mBuffer)
            {
                cOut.Write(decoded, 0, decoded.Length);
                cOut.Flush();
                decoded = GetBufferContents();
            }

            /*
             * RFC 5246 6.2.2. If the decompression function encounters a TLSCompressed.fragment that
             * would decompress to a length in excess of 2^14 bytes, it should report a fatal
             * decompression failure error.
             */
            CheckLength(decoded.Length, mPlaintextLimit, AlertDescription.decompression_failure);

            /*
             * RFC 5246 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert,
             * or ChangeCipherSpec content types.
             */
            if (decoded.Length < 1 && type != ContentType.application_data)
                throw new TlsFatalAlert(AlertDescription.illegal_parameter);

            return decoded;
        }

        /// <summary>
        /// Compresses, encrypts, frames and writes one record, then flushes the output stream.
        /// Does nothing while no write version has been set.
        /// </summary>
        /// <exception cref="TlsFatalAlert">internal_error if the record type or any length limit is violated.</exception>
        internal /*virtual*/ void WriteRecord(byte type, byte[] plaintext, int plaintextOffset, int plaintextLength)
        {
            // Never send anything until a valid ClientHello has been received
            if (mWriteVersion == null)
                return;

            /*
             * RFC 5246 6. Implementations MUST NOT send record types not defined in this document
             * unless negotiated by some extension.
             */
            CheckType(type, AlertDescription.internal_error);

            /*
             * RFC 5246 6.2.1 The length should not exceed 2^14.
             */
            CheckLength(plaintextLength, mPlaintextLimit, AlertDescription.internal_error);

            /*
             * RFC 5246 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert,
             * or ChangeCipherSpec content types.
             */
            if (plaintextLength < 1 && type != ContentType.application_data)
                throw new TlsFatalAlert(AlertDescription.internal_error);

            Stream cOut = mWriteCompression.Compress(mBuffer);

            long seqNo = mWriteSeqNo.NextValue(AlertDescription.internal_error);

            byte[] ciphertext;
            // A compression that returns mBuffer itself performs no transformation.
            if (cOut == mBuffer)
            {
                ciphertext = mWriteCipher.EncodePlaintext(seqNo, type, plaintext, plaintextOffset, plaintextLength);
            }
            else
            {
                cOut.Write(plaintext, plaintextOffset, plaintextLength);
                cOut.Flush();
                byte[] compressed = GetBufferContents();

                /*
                 * RFC 5246 6.2.2. Compression must be lossless and may not increase the content length
                 * by more than 1024 bytes.
                 */
                CheckLength(compressed.Length, plaintextLength + 1024, AlertDescription.internal_error);

                ciphertext = mWriteCipher.EncodePlaintext(seqNo, type, compressed, 0, compressed.Length);
            }

            /*
             * RFC 5246 6.2.3. The length may not exceed 2^14 + 2048.
             */
            CheckLength(ciphertext.Length, mCiphertextLimit, AlertDescription.internal_error);

            int recordLength = ciphertext.Length + TLS_HEADER_SIZE;
            byte[] record = BestHTTP.PlatformSupport.Memory.BufferPool.Get(recordLength, true);
            try
            {
                TlsUtilities.WriteUint8(type, record, TLS_HEADER_TYPE_OFFSET);
                TlsUtilities.WriteVersion(mWriteVersion, record, TLS_HEADER_VERSION_OFFSET);
                TlsUtilities.WriteUint16(ciphertext.Length, record, TLS_HEADER_LENGTH_OFFSET);
                Array.Copy(ciphertext, 0, record, TLS_HEADER_SIZE, ciphertext.Length);

                // Only the first recordLength bytes are meaningful; the pooled buffer may be larger.
                mOutput.Write(record, 0, recordLength);
            }
            finally
            {
                // Return the pooled buffer even if the transport write throws.
                BestHTTP.PlatformSupport.Memory.BufferPool.Release(record);
            }
            mOutput.Flush();
        }

        /// <summary>Replaces the handshake hash with the result of its <c>NotifyPrfDetermined()</c> call.</summary>
        internal /*virtual*/ void NotifyHelloComplete()
        {
            this.mHandshakeHash = mHandshakeHash.NotifyPrfDetermined();
        }

        /// <summary>The current handshake hash.</summary>
        internal /*virtual*/ TlsHandshakeHash HandshakeHash
        {
            get { return mHandshakeHash; }
        }

        /// <summary>
        /// A write-only stream that feeds written bytes into whichever handshake hash is current at write time.
        /// </summary>
        internal /*virtual*/ Stream HandshakeHashUpdater
        {
            get { return mHandshakeHashUpdater; }
        }

        /// <summary>
        /// Returns the current handshake hash.
        /// Replaces it internally with the hash produced by its <c>StopTracking()</c> call.
        /// </summary>
        internal /*virtual*/ TlsHandshakeHash PrepareToFinish()
        {
            TlsHandshakeHash result = mHandshakeHash;
            this.mHandshakeHash = mHandshakeHash.StopTracking();
            return result;
        }

        /// <summary>
        /// Disposes the input and output streams.
        /// An <see cref="IOException"/> from either is ignored; other exceptions propagate.
        /// </summary>
        internal /*virtual*/ void SafeClose()
        {
            try
            {
                BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.Dispose(mInput);
            }
            catch (IOException)
            {
            }

            try
            {
                BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.Dispose(mOutput);
            }
            catch (IOException)
            {
            }
        }

        /// <summary>Flushes the output stream.</summary>
        internal /*virtual*/ void Flush()
        {
            mOutput.Flush();
        }

        /// <summary>
        /// Validates the type, version and length fields of a received record header.
        /// Checks run in that order.
        /// </summary>
        /// <param name="recordHeader">Buffer holding at least <see cref="TLS_HEADER_SIZE"/> header bytes.</param>
        /// <param name="adoptReadVersion">
        /// Only consulted when read-version restriction is enabled.
        /// When true and no read version is fixed yet, the header's version becomes the read version.
        /// </param>
        /// <param name="length">Receives the declared fragment length.</param>
        /// <returns>The record's content type.</returns>
        /// <exception cref="TlsFatalAlert">
        /// The alert depends on the failing check.
        /// unexpected_message: unknown type.
        /// illegal_parameter: unacceptable version.
        /// record_overflow: length above the ciphertext limit.
        /// </exception>
        private byte ValidateRecordHeader(byte[] recordHeader, bool adoptReadVersion, out int length)
        {
            byte type = TlsUtilities.ReadUint8(recordHeader, TLS_HEADER_TYPE_OFFSET);

            /*
             * RFC 5246 6. If a TLS implementation receives an unexpected record type, it MUST send an
             * unexpected_message alert.
             */
            CheckType(type, AlertDescription.unexpected_message);

            if (!mRestrictReadVersion)
            {
                // Accept any {03,XX} version (see SetRestrictReadVersion).
                int version = TlsUtilities.ReadVersionRaw(recordHeader, TLS_HEADER_VERSION_OFFSET);
                if ((version & 0xffffff00) != 0x0300)
                    throw new TlsFatalAlert(AlertDescription.illegal_parameter);
            }
            else
            {
                ProtocolVersion version = TlsUtilities.ReadVersion(recordHeader, TLS_HEADER_VERSION_OFFSET);
                if (mReadVersion == null)
                {
                    // Only ReadRecord pins the version; CheckRecordHeader leaves it to be set later.
                    if (adoptReadVersion)
                        mReadVersion = version;
                }
                else if (!version.Equals(mReadVersion))
                {
                    throw new TlsFatalAlert(AlertDescription.illegal_parameter);
                }
            }

            length = TlsUtilities.ReadUint16(recordHeader, TLS_HEADER_LENGTH_OFFSET);
            CheckLength(length, mCiphertextLimit, AlertDescription.record_overflow);
            return type;
        }

        /// <summary>Copies out everything written to the scratch buffer and then empties it.</summary>
        private byte[] GetBufferContents()
        {
            byte[] contents = mBuffer.ToArray();
            mBuffer.SetLength(0);
            return contents;
        }

        /// <summary>
        /// Throws <paramref name="alertDescription"/> unless <paramref name="type"/> is one of the four accepted content types.
        /// The accepted types are application_data, alert, change_cipher_spec and handshake.
        /// </summary>
        private static void CheckType(byte type, byte alertDescription)
        {
            switch (type)
            {
                case ContentType.application_data:
                case ContentType.alert:
                case ContentType.change_cipher_spec:
                case ContentType.handshake:
                //case ContentType.heartbeat:
                    break;
                default:
                    throw new TlsFatalAlert(alertDescription);
            }
        }

        /// <summary>Throws <paramref name="alertDescription"/> when <paramref name="length"/> exceeds <paramref name="limit"/>.</summary>
        private static void CheckLength(int length, int limit, byte alertDescription)
        {
            if (length > limit)
                throw new TlsFatalAlert(alertDescription);
        }

        /// <summary>
        /// Forwards written bytes to the owning <see cref="RecordStream"/>'s handshake hash.
        /// The field is read on each write, so a swapped-in hash receives subsequent data.
        /// </summary>
        private class HandshakeHashUpdateStream
            : BaseOutputStream
        {
            private readonly RecordStream mOuter;

            public HandshakeHashUpdateStream(RecordStream mOuter)
            {
                this.mOuter = mOuter;
            }

            public override void Write(byte[] buf, int off, int len)
            {
                mOuter.mHandshakeHash.BlockUpdate(buf, off, len);
            }
        }

        /// <summary>
        /// Hands out consecutive record sequence numbers starting at 0.
        /// Once the counter wraps back to 0, every further request throws.
        /// Under C#'s default unchecked arithmetic, that wrap happens after 2^64 values.
        /// </summary>
        private class SequenceNumber
        {
            private long value = 0L;
            private bool exhausted = false;

            /// <summary>Returns the next sequence number.</summary>
            /// <exception cref="TlsFatalAlert"><paramref name="alertDescription"/> once the space is exhausted.</exception>
            internal long NextValue(byte alertDescription)
            {
                if (exhausted)
                {
                    throw new TlsFatalAlert(alertDescription);
                }
                long result = value;
                if (++value == 0)
                {
                    exhausted = true;
                }
                return result;
            }
        }
    }
}
#pragma warning restore
#endif